/* Copyright 2024 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_CORE_DATA_FLAT_MAP_UTILS_H_
#define TENSORFLOW_CORE_DATA_FLAT_MAP_UTILS_H_

#include <cstdint>
#include <deque>
#include <functional>
#include <memory>
#include <string>
#include <vector>

#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "tensorflow/core/common_runtime/process_function_library_runtime.h"
#include "tensorflow/core/data/captured_function.h"
#include "tensorflow/core/framework/dataset.h"
#include "tensorflow/core/framework/function.h"
#include "tensorflow/core/framework/function_handle_cache.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/kernels/data/iterator_ops.h"
#include "tsl/platform/refcount.h"
#include "tsl/platform/threadpool.h"

namespace tensorflow {
namespace data {

// Utility class that supports random access into a flat map dataset.
//
// A flat map applies the map function to each element of the input dataset,
// producing one dataset per input element, and flattens the results. This
// class:
//   * computes the total and cumulative cardinalities of those per-element
//     datasets,
//   * maps a flattened element position to the index of the dataset that
//     produces it, and
//   * creates iterators over the per-element datasets.
class FlatMapRandomAccessHandler {
 public:
  // Constructs the handler, saving the information it needs from `ctx`.
  //
  // `input_dataset` is the input dataset passed to `flat_map` (not the
  // flat_map dataset itself). It is stored as a raw pointer, so presumably it
  // is not owned and must outlive this handler (confirm in the .cc).
  // `captured_map_func` is the captured map function. It is stored by
  // reference (`captured_map_func_`), so it must outlive this handler.
  FlatMapRandomAccessHandler(OpKernelContext* ctx,
                             const DatasetBase* input_dataset,
                             CapturedFunction& captured_map_func);
  virtual ~FlatMapRandomAccessHandler();
  FlatMapRandomAccessHandler(const FlatMapRandomAccessHandler&) = delete;
  FlatMapRandomAccessHandler& operator=(const FlatMapRandomAccessHandler&) =
      delete;

  // Returns the dataset cardinality: the total number of elements across all
  // datasets produced by the map function. This is the last cumulative
  // cardinality.
  absl::StatusOr<int64_t> Cardinality();

  // Returns the cumulative cardinality at the `index`-th dataset.
  // NOTE(review): confirm in the .cc whether the count includes dataset
  // `index` itself, and how an out-of-range `index` is reported.
  absl::StatusOr<int64_t> CumulativeCardinality(size_t index);

  // Given the flattened element position `element_position`, returns the index
  // of the dataset to which the element belongs.
  absl::StatusOr<int64_t> GetDatasetIndex(size_t element_position);

  // Creates iterators over the datasets produced by the map function.
  // `parent` and `prefix` presumably become the parent iterator and the name
  // prefix of the created iterators. TODO: confirm against the .cc.
  absl::StatusOr<std::vector<std::unique_ptr<IteratorBase>>> MakeInputIterators(
      IteratorContext* ctx, const DatasetBaseIterator* parent,
      const std::string& prefix);

 private:
  // Computes the cumulative cardinalities of the datasets produced by the map
  // function.
  absl::StatusOr<std::vector<int64_t>> ComputeCardinalities();

  // Creates the input datasets. Each dataset is the result of applying the map
  // function to one element from the input iterator.
  absl::StatusOr<std::deque<DatasetBase*>> MakeInputDatasets() const;
  // Creates a single input dataset by applying `map_func` to `input_tensors`
  // (one element of the input dataset).
  absl::StatusOr<DatasetBase*> MakeInputDataset(
      std::vector<Tensor> input_tensors,
      const InstantiatedCapturedFunction& map_func) const;

  // The dataset passed to `flat_map` (raw pointer; presumably not owned).
  const DatasetBase* input_dataset_;
  // The captured map function. This is a reference, so the referenced
  // function must outlive this handler.
  CapturedFunction& captured_map_func_;

  // The iterator context which bundles together the necessary runtime support
  // to create and get elements from the input dataset. The members below it
  // appear to be the runtime pieces that context is built from.
  // NOTE(review): ctx_ is declared first, so it is constructed first and
  // destroyed last. Any reordering of these members changes that order.
  std::unique_ptr<IteratorContext> ctx_;
  FunctionLibraryRuntime* flr_;
  std::unique_ptr<FunctionLibraryDefinition> flib_def_;
  std::unique_ptr<ProcessFunctionLibraryRuntime> pflr_;
  std::unique_ptr<thread::ThreadPool> interop_threadpool_;
  std::unique_ptr<FunctionHandleCache> function_handle_cache_;
  std::function<void(std::function<void()>)> runner_;
  ResourceMgr resource_mgr_;
  CancellationManager cancellation_manager_;
  UnboundedThreadPool unbounded_thread_pool_;

  // Input datasets generated by running the map function. Each dataset is the
  // result of applying the map function to one element from the input iterator.
  // NOTE(review): these are raw pointers; presumably the destructor releases
  // them. Confirm in the .cc.
  std::deque<DatasetBase*> input_datasets_;

  // Cumulative cardinalities. Before `ComputeCardinalities` is called, this is
  // an empty vector. After `ComputeCardinalities` is called, the last element
  // is the dataset cardinality. It is held as a StatusOr, presumably so that
  // a failure from `ComputeCardinalities` can be stored and returned.
  absl::StatusOr<std::vector<int64_t>> cumulative_cardinalities_ =
      std::vector<int64_t>{};
};

}  // namespace data
}  // namespace tensorflow

#endif  // TENSORFLOW_CORE_DATA_FLAT_MAP_UTILS_H_
